Camera calibration

Estimating camera intrinsic parameters from checkerboard images.

Inputs: Images of checkerboard pattern (OpenCV sample data)
Outputs: Camera intrinsics (focal length, principal point, distortion coefficients)

Features used:

  • Var subclass for camera intrinsics

  • SE3Var for camera extrinsic poses

  • @jaxls.Cost.factory for reprojection error

  • OpenCV for chessboard corner detection

Hide code cell source

import sys
from loguru import logger

logger.remove()
logger.add(sys.stdout, format="<level>{level: <8}</level> | {message}");
import urllib.request
from pathlib import Path

import cv2
import jax
import jax.numpy as jnp
import jaxls
import jaxlie
import numpy as np
from scipy import ndimage

Download OpenCV sample images

Download the calibration images from the OpenCV repository:

def download_calibration_images(
    cache_dir: Path = Path("/tmp/opencv_calib"),
) -> list[Path]:
    """Download OpenCV sample calibration images.

    Images are fetched from the OpenCV GitHub repository and cached on disk,
    so repeated calls only download files that are not already present.

    Args:
        cache_dir: Directory to cache downloaded images

    Returns:
        List of paths to downloaded image files
    """
    cache_dir.mkdir(parents=True, exist_ok=True)
    base_url = "https://raw.githubusercontent.com/opencv/opencv/master/samples/data"

    # Note: left10.jpg doesn't exist in the OpenCV repo
    image_indices = [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14]

    image_paths = []
    for i in image_indices:
        filename = f"left{i:02d}.jpg"
        local_path = cache_dir / filename

        if not local_path.exists():
            # Bug fix: URL and log line previously contained a literal
            # "(unknown)" placeholder instead of the image filename, so the
            # request targeted a nonexistent resource.
            url = f"{base_url}/{filename}"
            logger.info(f"Downloading {filename}...")
            urllib.request.urlretrieve(url, local_path)

        image_paths.append(local_path)

    return image_paths


# Fetch the calibration images (cached on disk after the first run).
image_paths = download_calibration_images()
print(f"Downloaded {len(image_paths)} calibration images")
Downloaded 13 calibration images

Detect chessboard corners

Use OpenCV to detect the inner corners of the 9x6 chessboard pattern:

# Chessboard geometry: the detector looks for the 9x6 grid of inner corners.
board_cols, board_rows = 9, 6
square_size = 0.025  # 25mm squares (approximate)

# 3D corner coordinates on the board plane (Z=0). Rows are ordered with the
# column index varying fastest: (0,0), (1,0), ..., (8,0), (0,1), ...
_corner_grid = np.stack(
    np.meshgrid(np.arange(board_cols), np.arange(board_rows)), axis=-1
).reshape(-1, 2)
board_points_3d = np.zeros((board_rows * board_cols, 3), np.float32)
board_points_3d[:, :2] = _corner_grid * square_size
board_points_3d = jnp.array(board_points_3d)

print(f"Chessboard: {board_cols}x{board_rows} = {len(board_points_3d)} corners")
print(
    f"Board size: {board_cols * square_size * 1000:.0f}mm x {board_rows * square_size * 1000:.0f}mm"
)
Chessboard: 9x6 = 54 corners
Board size: 225mm x 150mm
# Detect corners in all images
observations_2d: list[jax.Array] = []
valid_image_indices: list[int] = []
# (width, height), read from the first image; None until then.
image_size = None

# Sub-pixel refinement termination: stop after 30 iterations or once the
# corner estimate moves by less than 0.001 px.
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 30, 0.001)

for i, path in enumerate(image_paths):
    img = cv2.imread(str(path))
    if image_size is None:
        image_size = (img.shape[1], img.shape[0])  # (width, height)

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    ret, corners = cv2.findChessboardCorners(gray, (board_cols, board_rows), None)

    if ret:
        # Refine corner positions
        corners_refined = cv2.cornerSubPix(gray, corners, (11, 11), (-1, -1), criteria)
        observations_2d.append(jnp.array(corners_refined.squeeze()))
        valid_image_indices.append(i)
        print(f"  Image {i + 1:2d}: ✓ Found {len(corners)} corners")
    else:
        print(f"  Image {i + 1:2d}: ✗ Chessboard not found")

num_views = len(observations_2d)
print(f"\nSuccessfully detected corners in {num_views}/{len(image_paths)} images")
print(f"Image size: {image_size[0]}x{image_size[1]}")
  Image  1: ✓ Found 54 corners
  Image  2: ✓ Found 54 corners
  Image  3: ✓ Found 54 corners
  Image  4: ✓ Found 54 corners
  Image  5: ✓ Found 54 corners
  Image  6: ✓ Found 54 corners
  Image  7: ✓ Found 54 corners
  Image  8: ✓ Found 54 corners
  Image  9: ✓ Found 54 corners
  Image 10: ✓ Found 54 corners
  Image 11: ✓ Found 54 corners
  Image 12: ✓ Found 54 corners
  Image 13: ✓ Found 54 corners

Successfully detected corners in 13/13 images
Image size: 640x480

Hide code cell source

import plotly.graph_objects as go
from plotly.subplots import make_subplots
from IPython.display import HTML

# Show sample input images with detected corners
sample_indices = [0, 3, 6]  # Show 3 sample images

fig_samples = make_subplots(
    rows=1,
    cols=len(sample_indices),
    subplot_titles=[f"Image {valid_image_indices[i] + 1}" for i in sample_indices],
)

for col, idx in enumerate(sample_indices):
    img = cv2.imread(str(image_paths[valid_image_indices[idx]]))
    img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Draw detected corners
    # drawChessboardCorners expects the (N, 1, 2) layout produced by
    # findChessboardCorners, hence the reshape.
    corners = np.array(observations_2d[idx])
    cv2.drawChessboardCorners(
        img_rgb, (board_cols, board_rows), corners.reshape(-1, 1, 2), True
    )

    fig_samples.add_trace(
        go.Image(z=img_rgb),
        row=1,
        col=col + 1,
    )

# Hide tick labels: these subplots are images, not data axes.
fig_samples.update_xaxes(showticklabels=False)
fig_samples.update_yaxes(showticklabels=False)
fig_samples.update_layout(
    height=280,
    margin=dict(t=40, b=20, l=20, r=20),
)
HTML(fig_samples.to_html(full_html=False, include_plotlyjs="cdn"))

Camera model

We use the Brown-Conrady distortion model (same as OpenCV):

\[x' = x_n (1 + k_1 r^2 + k_2 r^4) + 2 p_1 x_n y_n + p_2 (r^2 + 2 x_n^2)\]
\[y' = y_n (1 + k_1 r^2 + k_2 r^4) + p_1 (r^2 + 2 y_n^2) + 2 p_2 x_n y_n\]
\[u = f_x \cdot x' + c_x, \quad v = f_y \cdot y' + c_y\]

where \((x_n, y_n)\) are normalized coordinates, \(r^2 = x_n^2 + y_n^2\), \((k_1, k_2)\) are radial distortion coefficients, and \((p_1, p_2)\) are tangential distortion coefficients.

class IntrinsicsVar(
    jaxls.Var[jax.Array],
    default_factory=lambda: jnp.array([500.0, 500.0, 320.0, 240.0, 0.0, 0.0, 0.0, 0.0]),
):
    """Camera intrinsics: [fx, fy, cx, cy, k1, k2, p1, p2].

    A flat 8-vector optimization variable: focal lengths (fx, fy),
    principal point (cx, cy), radial distortion coefficients (k1, k2),
    and tangential distortion coefficients (p1, p2).
    """


@jax.jit
def project_brown_conrady(
    points_camera: jax.Array,  # (N, 3) points in camera frame
    intrinsics: jax.Array,  # [fx, fy, cx, cy, k1, k2, p1, p2]
) -> jax.Array:
    """Project 3D camera-frame points to pixel coordinates.

    Applies the Brown-Conrady distortion model (two radial plus two
    tangential coefficients) followed by the pinhole projection.

    Args:
        points_camera: 3D points in camera frame (N, 3)
        intrinsics: Camera intrinsics [fx, fy, cx, cy, k1, k2, p1, p2]

    Returns:
        2D projected points (N, 2)
    """
    fx, fy, cx, cy, k1, k2, p1, p2 = intrinsics

    # Perspective division; depth is clamped away from zero so points on
    # (or behind) the camera plane cannot produce NaNs/Infs.
    depth = jnp.maximum(points_camera[..., 2], 1e-6)
    xn = points_camera[..., 0] / depth
    yn = points_camera[..., 1] / depth

    # Radial distortion factor: 1 + k1*r^2 + k2*r^4.
    r2 = xn**2 + yn**2
    radial = 1.0 + k1 * r2 + k2 * r2**2

    # Distorted normalized coordinates (radial term plus tangential terms).
    xd = xn * radial + 2 * p1 * xn * yn + p2 * (r2 + 2 * xn**2)
    yd = yn * radial + p1 * (r2 + 2 * yn**2) + 2 * p2 * xn * yn

    # Pinhole projection into pixel coordinates.
    return jnp.stack([fx * xd + cx, fy * yd + cy], axis=-1)

Problem construction

We optimize camera intrinsics and all extrinsic poses jointly using reprojection error:

# Variables
# One intrinsics variable shared across every view, plus one SE(3)
# extrinsic pose variable per successfully detected view.
intrinsics_var = IntrinsicsVar(id=0)
pose_vars = [jaxls.SE3Var(id=i) for i in range(num_views)]
@jaxls.Cost.factory
def reprojection_cost(
    vals: jaxls.VarValues,
    intrinsics_var: IntrinsicsVar,
    pose_var: jaxls.SE3Var,
    points_3d: jax.Array,  # (N, 3) batch of 3D points
    observed_2d: jax.Array,  # (N, 2) batch of observed 2D points
) -> jax.Array:
    """Pixel-space reprojection residuals for all corners in one view.

    Returns the flattened (2N,) vector of (du, dv) residuals between the
    projected board points and the detected corner positions.
    """
    # Current estimates for this view.
    K = vals[intrinsics_var]
    T_cam_board = vals[pose_var]

    # Board frame -> camera frame -> pixel coordinates.
    points_cam = jax.vmap(T_cam_board.apply)(points_3d)
    projected_px = project_brown_conrady(points_cam, K)

    # One (du, dv) pair per corner, flattened.
    return (projected_px - observed_2d).flatten()
# Build costs using batched construction - one cost per view
# Each cost bundles every corner residual of a single image, sharing the
# global intrinsics variable and that view's own pose variable.
costs: list[jaxls.Cost] = [
    reprojection_cost(
        intrinsics_var,
        pose_vars[view_idx],
        board_points_3d,  # All 3D points
        observations_2d[view_idx],  # All 2D observations for this view
    )
    for view_idx in range(num_views)
]

print(f"Created {len(costs)} batched reprojection costs ({num_views} views)")
Created 13 batched reprojection costs (13 views)
# Rough initial guesses: focal lengths set to half the image width, the
# principal point at the image center, and all distortion terms zero.
_width, _height = image_size
init_fx = _width / 2.0
init_fy = _width / 2.0
init_cx = _width / 2.0
init_cy = _height / 2.0
init_intrinsics = jnp.array([init_fx, init_fy, init_cx, init_cy, 0.0, 0.0, 0.0, 0.0])

print(
    f"Initial intrinsics: fx={init_fx:.0f}, fy={init_fy:.0f}, cx={init_cx:.0f}, cy={init_cy:.0f}"
)
Initial intrinsics: fx=320, fy=320, cx=320, cy=240
def estimate_initial_pose(
    observed_corners: jax.Array, intrinsics: jax.Array
) -> jaxlie.SE3:
    """Estimate initial pose using OpenCV's solvePnP.

    Args:
        observed_corners: Detected 2D corner positions (N, 2)
        intrinsics: Camera intrinsics [fx, fy, cx, cy, k1, k2, p1, p2]

    Returns:
        Estimated camera pose as SE3
    """
    fx, fy, cx, cy, k1, k2, p1, p2 = intrinsics

    # Pack the intrinsics into OpenCV's 3x3 camera matrix and distortion vector.
    K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float64)
    dist = np.array([k1, k2, p1, p2], dtype=np.float64)

    # PnP returns the board->camera rotation (Rodrigues vector) and translation.
    _, rvec, tvec = cv2.solvePnP(
        np.array(board_points_3d),
        np.array(observed_corners),
        K,
        dist,
    )

    # Convert to a jaxlie SE3 via a rotation matrix.
    R_mat, _ = cv2.Rodrigues(rvec)
    return jaxlie.SE3.from_rotation_and_translation(
        jaxlie.SO3.from_matrix(jnp.array(R_mat)),
        jnp.array(tvec.squeeze()),
    )


# Estimate initial poses
# One PnP solve per view gives starting poses close enough for the
# joint nonlinear optimization to converge.
init_poses = [estimate_initial_pose(obs, init_intrinsics) for obs in observations_2d]
print(f"Estimated {len(init_poses)} initial poses using PnP")
Estimated 13 initial poses using PnP
# Create initial values
initial_vals = jaxls.VarValues.make(
    [intrinsics_var.with_value(init_intrinsics)]
    + [pose_vars[i].with_value(init_poses[i]) for i in range(num_views)]
)

# Create and analyze problem
# analyze() inspects the cost/variable structure (e.g. for vectorization).
problem = jaxls.LeastSquaresProblem(costs, [intrinsics_var] + pose_vars).analyze()
INFO     | Building optimization problem with 13 terms and 14 variables: 13 costs, 0 eq_zero, 0 leq_zero, 0 geq_zero
INFO     | Vectorizing group with 13 costs, 2 variables each: reprojection_cost

Solving

# Run the nonlinear least-squares solver from the PnP-based initial values.
solution = problem.solve(initial_vals)
INFO     |  step #1: cost=11776.8213 lambd=0.0005 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 11776.82129 (avg 8.38805)
INFO     |  step #2: cost=11776.8213 lambd=0.0010 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 11776.82129 (avg 8.38805)
INFO     |  step #3: cost=11776.8213 lambd=0.0020 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 11776.82129 (avg 8.38805)
INFO     |  step #4: cost=11776.8213 lambd=0.0040 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 11776.82129 (avg 8.38805)
INFO     |  step #5: cost=11776.8213 lambd=0.0080 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 11776.82129 (avg 8.38805)
INFO     |  step #6: cost=11776.8213 lambd=0.0160 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 11776.82129 (avg 8.38805)
INFO     |  step #7: cost=11776.8213 lambd=0.0320 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 11776.82129 (avg 8.38805)
INFO     |  step #8: cost=11776.8213 lambd=0.0640 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 11776.82129 (avg 8.38805)
INFO     |  step #9: cost=11776.8213 lambd=0.1280 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 11776.82129 (avg 8.38805)
INFO     |  step #10: cost=11776.8213 lambd=0.2560 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 11776.82129 (avg 8.38805)
INFO     |      accepted=True ATb_norm=3.08e+04 cost_prev=11776.8213 cost_new=7542.5205
INFO     |  step #11: cost=7542.5205 lambd=0.1280 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 7542.52051 (avg 5.37217)
INFO     |      accepted=True ATb_norm=6.22e+05 cost_prev=7542.5205 cost_new=4098.8364
INFO     |  step #12: cost=4098.8364 lambd=0.0640 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 4098.83643 (avg 2.91940)
INFO     |      accepted=True ATb_norm=5.82e+05 cost_prev=4098.8364 cost_new=1063.1411
INFO     |  step #13: cost=1063.1411 lambd=0.0320 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 1063.14111 (avg 0.75722)
INFO     |      accepted=True ATb_norm=2.82e+05 cost_prev=1063.1411 cost_new=345.1826
INFO     |  step #14: cost=345.1826 lambd=0.0160 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 345.18265 (avg 0.24586)
INFO     |  step #15: cost=345.1826 lambd=0.0320 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 345.18265 (avg 0.24586)
INFO     |  step #16: cost=345.1826 lambd=0.0640 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 345.18265 (avg 0.24586)
INFO     |  step #17: cost=345.1826 lambd=0.1280 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 345.18265 (avg 0.24586)
INFO     |  step #18: cost=345.1826 lambd=0.2560 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 345.18265 (avg 0.24586)
INFO     |  step #19: cost=345.1826 lambd=0.5120 inexact_tol=1.0e-02
INFO     |      - reprojection_cost(13): 345.18265 (avg 0.24586)
INFO     |      accepted=True ATb_norm=4.79e+03 cost_prev=345.1826 cost_new=272.9070
INFO     |  step #20: cost=272.9070 lambd=0.2560 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 272.90695 (avg 0.19438)
INFO     |  step #21: cost=272.9070 lambd=0.5120 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 272.90695 (avg 0.19438)
INFO     |      accepted=True ATb_norm=8.49e+03 cost_prev=272.9070 cost_new=227.6678
INFO     |  step #22: cost=227.6678 lambd=0.2560 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 227.66782 (avg 0.16216)
INFO     |      accepted=True ATb_norm=4.25e+03 cost_prev=227.6678 cost_new=178.2618
INFO     |  step #23: cost=178.2618 lambd=0.1280 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 178.26176 (avg 0.12697)
INFO     |  step #24: cost=178.2618 lambd=0.2560 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 178.26176 (avg 0.12697)
INFO     |      accepted=True ATb_norm=8.40e+03 cost_prev=178.2618 cost_new=151.6444
INFO     |  step #25: cost=151.6444 lambd=0.1280 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 151.64439 (avg 0.10801)
INFO     |  step #26: cost=151.6444 lambd=0.2560 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 151.64439 (avg 0.10801)
INFO     |      accepted=True ATb_norm=4.14e+03 cost_prev=151.6444 cost_new=137.1563
INFO     |  step #27: cost=137.1563 lambd=0.1280 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 137.15631 (avg 0.09769)
INFO     |  step #28: cost=137.1563 lambd=0.2560 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 137.15631 (avg 0.09769)
INFO     |      accepted=True ATb_norm=2.28e+03 cost_prev=137.1563 cost_new=129.0163
INFO     |  step #29: cost=129.0163 lambd=0.1280 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 129.01627 (avg 0.09189)
INFO     |  step #30: cost=129.0163 lambd=0.2560 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 129.01627 (avg 0.09189)
INFO     |      accepted=True ATb_norm=2.07e+03 cost_prev=129.0163 cost_new=124.3110
INFO     |  step #31: cost=124.3110 lambd=0.1280 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 124.31096 (avg 0.08854)
INFO     |  step #32: cost=124.3110 lambd=0.2560 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 124.31096 (avg 0.08854)
INFO     |      accepted=True ATb_norm=1.16e+03 cost_prev=124.3110 cost_new=121.5538
INFO     |  step #33: cost=121.5538 lambd=0.1280 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 121.55382 (avg 0.08658)
INFO     |  step #34: cost=121.5538 lambd=0.2560 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 121.55382 (avg 0.08658)
INFO     |      accepted=True ATb_norm=6.30e+02 cost_prev=121.5538 cost_new=119.9181
INFO     |  step #35: cost=119.9181 lambd=0.1280 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 119.91814 (avg 0.08541)
INFO     |  step #36: cost=119.9181 lambd=0.2560 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 119.91814 (avg 0.08541)
INFO     |      accepted=True ATb_norm=3.49e+02 cost_prev=119.9181 cost_new=118.9377
INFO     |  step #37: cost=118.9377 lambd=0.1280 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 118.93771 (avg 0.08471)
INFO     |  step #38: cost=118.9377 lambd=0.2560 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 118.93771 (avg 0.08471)
INFO     |      accepted=True ATb_norm=1.91e+02 cost_prev=118.9377 cost_new=118.3439
INFO     |  step #39: cost=118.3439 lambd=0.1280 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 118.34389 (avg 0.08429)
INFO     |  step #40: cost=118.3439 lambd=0.2560 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 118.34389 (avg 0.08429)
INFO     |      accepted=True ATb_norm=1.11e+02 cost_prev=118.3439 cost_new=117.9819
INFO     |  step #41: cost=117.9819 lambd=0.1280 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 117.98190 (avg 0.08403)
INFO     |  step #42: cost=117.9819 lambd=0.2560 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 117.98190 (avg 0.08403)
INFO     |      accepted=True ATb_norm=5.83e+01 cost_prev=117.9819 cost_new=117.7601
INFO     |  step #43: cost=117.7601 lambd=0.1280 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 117.76015 (avg 0.08387)
INFO     |      accepted=True ATb_norm=3.57e+01 cost_prev=117.7601 cost_new=117.5524
INFO     |  step #44: cost=117.5524 lambd=0.0640 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 117.55236 (avg 0.08373)
INFO     |  step #45: cost=117.5524 lambd=0.1280 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 117.55236 (avg 0.08373)
INFO     |      accepted=True ATb_norm=5.77e+01 cost_prev=117.5524 cost_new=117.4660
INFO     |  step #46: cost=117.4660 lambd=0.0640 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 117.46597 (avg 0.08367)
INFO     |  step #47: cost=117.4660 lambd=0.1280 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 117.46597 (avg 0.08367)
INFO     |      accepted=True ATb_norm=2.45e+01 cost_prev=117.4660 cost_new=117.4280
INFO     |  step #48: cost=117.4280 lambd=0.0640 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 117.42802 (avg 0.08364)
INFO     |  step #49: cost=117.4280 lambd=0.1280 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 117.42802 (avg 0.08364)
INFO     |      accepted=True ATb_norm=1.12e+01 cost_prev=117.4280 cost_new=117.4125
INFO     |  step #50: cost=117.4125 lambd=0.0640 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 117.41252 (avg 0.08363)
INFO     |  step #51: cost=117.4125 lambd=0.1280 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 117.41252 (avg 0.08363)
INFO     |      accepted=True ATb_norm=4.61e+00 cost_prev=117.4125 cost_new=117.4060
INFO     |  step #52: cost=117.4060 lambd=0.0640 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 117.40602 (avg 0.08362)
INFO     |  step #53: cost=117.4060 lambd=0.1280 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 117.40602 (avg 0.08362)
INFO     |      accepted=True ATb_norm=7.86e+00 cost_prev=117.4060 cost_new=117.4029
INFO     |  step #54: cost=117.4029 lambd=0.0640 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 117.40290 (avg 0.08362)
INFO     |  step #55: cost=117.4029 lambd=0.1280 inexact_tol=2.6e-04
INFO     |      - reprojection_cost(13): 117.40290 (avg 0.08362)
INFO     |      accepted=True ATb_norm=7.81e+00 cost_prev=117.4029 cost_new=117.4020
INFO     | Terminated @ iteration #55: cost=117.4020 criteria=[1 0 0 0], term_deltas=7.9e-06,3.0e+00,7.9e-05
# Tabulate initial vs. optimized intrinsics side by side.
est_intrinsics = solution[intrinsics_var]

param_names = ["fx", "fy", "cx", "cy", "k1", "k2", "p1", "p2"]
print("Estimated intrinsics:")
print(f"  {'Parameter':<12} {'Initial':>12} {'Estimated':>12}")
print(f"  {'-' * 38}")
for name, init_val, est_val in zip(param_names, init_intrinsics, est_intrinsics):
    print(f"  {name:<12} {float(init_val):>12.4f} {float(est_val):>12.4f}")
Estimated intrinsics:
  Parameter         Initial    Estimated
  --------------------------------------
  fx               320.0000     536.3744
  fy               320.0000     536.3217
  cx               320.0000     342.3835
  cy               240.0000     235.5228
  k1                 0.0000      -0.2785
  k2                 0.0000       0.0668
  p1                 0.0000       0.0018
  p2                 0.0000      -0.0003

Visualization

Hide code cell source

def compute_reprojection_errors(
    intrinsics: jax.Array, poses: list[jaxlie.SE3]
) -> tuple[list[jax.Array], list[jax.Array]]:
    """Compute reprojection errors for all views.

    Args:
        intrinsics: Camera intrinsics [fx, fy, cx, cy, k1, k2, p1, p2]
        poses: List of camera poses (one per view)

    Returns:
        Tuple of (projected_points, errors) where each is a list per view
    """
    projected_per_view: list[jax.Array] = []
    errors_per_view: list[jax.Array] = []
    for observed, pose in zip(observations_2d, poses):
        # Project the board corners through this view's pose.
        points_cam = jax.vmap(pose.apply)(board_points_3d)
        reprojected = project_brown_conrady(points_cam, intrinsics)
        projected_per_view.append(reprojected)
        # Per-corner Euclidean error in pixels.
        errors_per_view.append(jnp.linalg.norm(reprojected - observed, axis=-1))
    return projected_per_view, errors_per_view


# Compute errors before and after
init_projected, init_errors = compute_reprojection_errors(init_intrinsics, init_poses)
est_poses = [solution[pose_vars[i]] for i in range(num_views)]
est_projected, est_errors = compute_reprojection_errors(est_intrinsics, est_poses)

# RMSE pooled over every corner in every view.
init_rmse = float(jnp.sqrt(jnp.mean(jnp.concatenate([e**2 for e in init_errors]))))
est_rmse = float(jnp.sqrt(jnp.mean(jnp.concatenate([e**2 for e in est_errors]))))
print(
    f"Reprojection RMSE: {init_rmse:.3f} px (initial) -> {est_rmse:.3f} px (optimized)"
)
Reprojection RMSE: 4.096 px (initial) -> 0.409 px (optimized)

Hide code cell source

# Reprojection error distribution
# Pool per-corner errors across all views for the histograms.
all_init_errors = jnp.concatenate(init_errors)
all_est_errors = jnp.concatenate(est_errors)

fig_errors = go.Figure()
fig_errors.add_trace(
    go.Histogram(
        x=all_init_errors,
        name="Initial",
        marker_color="tomato",
        opacity=0.7,
        nbinsx=30,
    )
)
fig_errors.add_trace(
    go.Histogram(
        x=all_est_errors,
        name="Optimized",
        marker_color="steelblue",
        opacity=0.7,
        nbinsx=30,
    )
)
# Semi-transparent overlay so both distributions stay visible.
fig_errors.update_layout(
    title="Reprojection Error Distribution",
    xaxis_title="Error (pixels)",
    yaxis_title="Count",
    barmode="overlay",
    height=300,
    margin=dict(t=40, b=40, l=60, r=40),
    legend=dict(yanchor="top", y=0.99, xanchor="right", x=0.99),
)
HTML(fig_errors.to_html(full_html=False, include_plotlyjs="cdn"))

Hide code cell source

# Camera poses (top-down view)
# The optimized pose maps board-frame points into the camera frame, so the
# camera's position in board coordinates is the translation of the inverse.
cam_positions = [pose.inverse().translation() for pose in est_poses]
cam_x = [float(p[0]) for p in cam_positions]
cam_y = [float(p[1]) for p in cam_positions]

# Chessboard outline
# Closed polygon (first corner repeated) tracing the board's outer edge.
board_corners = jnp.array(
    [
        [0, 0, 0],
        [board_cols * square_size, 0, 0],
        [board_cols * square_size, board_rows * square_size, 0],
        [0, board_rows * square_size, 0],
        [0, 0, 0],
    ]
)

fig_poses = go.Figure()
fig_poses.add_trace(
    go.Scatter(
        x=board_corners[:, 0] * 1000,  # meters -> mm
        y=board_corners[:, 1] * 1000,
        mode="lines",
        line=dict(color="gray", width=2),
        name="Board",
    )
)
fig_poses.add_trace(
    go.Scatter(
        x=[c * 1000 for c in cam_x],
        y=[c * 1000 for c in cam_y],
        mode="markers+text",
        marker=dict(size=10, color="steelblue"),
        text=[str(i + 1) for i in range(num_views)],
        textposition="top center",
        name="Cameras",
    )
)
fig_poses.update_layout(
    title="Camera Poses (top view)",
    xaxis_title="X (mm)",
    yaxis_title="Y (mm)",
    xaxis=dict(scaleanchor="y", scaleratio=1),
    height=350,
    margin=dict(t=40, b=40, l=60, r=40),
    showlegend=False,
)
HTML(fig_poses.to_html(full_html=False, include_plotlyjs="cdn"))

Hide code cell source

# Single view comparison: initial vs optimized reprojection
view_idx = 1
obs = observations_2d[view_idx]

fig_view = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=(
        f"Initial (RMSE={float(jnp.sqrt(jnp.mean(init_errors[view_idx] ** 2))):.2f}px)",
        f"Optimized (RMSE={float(jnp.sqrt(jnp.mean(est_errors[view_idx] ** 2))):.2f}px)",
    ),
)

# Initial projection
fig_view.add_trace(
    go.Scatter(
        x=obs[:, 0],
        y=obs[:, 1],
        mode="markers",
        marker=dict(size=8, color="green", symbol="circle"),
        name="Observed",
        showlegend=True,
    ),
    row=1,
    col=1,
)
fig_view.add_trace(
    go.Scatter(
        x=init_projected[view_idx][:, 0],
        y=init_projected[view_idx][:, 1],
        mode="markers",
        marker=dict(size=6, color="tomato", symbol="x"),
        name="Projected",
        showlegend=True,
    ),
    row=1,
    col=1,
)
# Line segments connecting each observed corner to its initial projection.
for j in range(len(obs)):
    fig_view.add_trace(
        go.Scatter(
            x=[obs[j, 0], init_projected[view_idx][j, 0]],
            y=[obs[j, 1], init_projected[view_idx][j, 1]],
            mode="lines",
            line=dict(color="tomato", width=0.5),
            showlegend=False,
            hoverinfo="skip",
        ),
        row=1,
        col=1,
    )

# Optimized projection
fig_view.add_trace(
    go.Scatter(
        x=obs[:, 0],
        y=obs[:, 1],
        mode="markers",
        marker=dict(size=8, color="green", symbol="circle"),
        showlegend=False,
    ),
    row=1,
    col=2,
)
fig_view.add_trace(
    go.Scatter(
        x=est_projected[view_idx][:, 0],
        y=est_projected[view_idx][:, 1],
        mode="markers",
        marker=dict(size=6, color="steelblue", symbol="x"),
        showlegend=False,
    ),
    row=1,
    col=2,
)
# Residual segments for the optimized projection.
for j in range(len(obs)):
    fig_view.add_trace(
        go.Scatter(
            x=[obs[j, 0], est_projected[view_idx][j, 0]],
            y=[obs[j, 1], est_projected[view_idx][j, 1]],
            mode="lines",
            line=dict(color="steelblue", width=0.5),
            showlegend=False,
            hoverinfo="skip",
        ),
        row=1,
        col=2,
    )

fig_view.update_xaxes(title_text="u (pixels)")
# Reverse the y-axis so the plot matches image coordinates (v grows downward).
fig_view.update_yaxes(title_text="v (pixels)", autorange="reversed")
fig_view.update_layout(
    height=400,
    margin=dict(t=40, b=80, l=60, r=40),
    legend=dict(orientation="h", yanchor="top", y=-0.15, xanchor="center", x=0.5),
)
HTML(fig_view.to_html(full_html=False, include_plotlyjs="cdn"))

Undistortion

Apply the estimated distortion parameters to rectify the images:

Hide code cell source

def undistort_image(img: np.ndarray, intrinsics: jax.Array) -> np.ndarray:
    """Undistort an image using the estimated intrinsics.

    For every pixel of the output (undistorted) image, the forward
    Brown-Conrady distortion is applied to find the corresponding sample
    location in the distorted input, which is then bilinearly interpolated
    with scipy.ndimage.map_coordinates.

    Args:
        img: Input image (H, W, 3) or (H, W)
        intrinsics: Camera intrinsics [fx, fy, cx, cy, k1, k2, p1, p2]

    Returns:
        Undistorted image with same shape as input
    """
    fx, fy, cx, cy, k1, k2, p1, p2 = (float(x) for x in intrinsics)
    height, width = img.shape[:2]

    # Output pixel grid, mapped to undistorted normalized coordinates.
    u, v = np.meshgrid(np.arange(width), np.arange(height))
    x_n = (u - cx) / fx
    y_n = (v - cy) / fy

    # Forward distortion: radial polynomial plus tangential terms.
    r2 = x_n**2 + y_n**2
    radial = 1.0 + k1 * r2 + k2 * r2**2
    x_d = x_n * radial + 2 * p1 * x_n * y_n + p2 * (r2 + 2 * x_n**2)
    y_d = y_n * radial + p1 * (r2 + 2 * y_n**2) + 2 * p2 * x_n * y_n

    # Sample coordinates in the distorted source image, in pixels.
    u_src = fx * x_d + cx
    v_src = fy * y_d + cy

    # Bilinear sampling; out-of-bounds samples are filled with zeros.
    def _sample(channel: np.ndarray) -> np.ndarray:
        return ndimage.map_coordinates(
            channel, [v_src, u_src], order=1, mode="constant", cval=0
        )

    if len(img.shape) == 3:
        # Color image - resample each channel independently.
        undistorted = np.zeros_like(img)
        for c in range(3):
            undistorted[:, :, c] = _sample(img[:, :, c])
        return undistorted
    return _sample(img)


def compute_distortion_at_points(
    intrinsics: jax.Array, points: jax.Array
) -> np.ndarray:
    """Compute distortion magnitude at specific pixel locations.

    Measures how far the full Brown-Conrady model (radial AND tangential
    terms) displaces each pixel. The previous implementation dropped the
    tangential (p1, p2) terms, making its values inconsistent with the
    dense map produced by `compute_distortion_magnitude`.

    Args:
        intrinsics: Camera intrinsics [fx, fy, cx, cy, k1, k2, p1, p2]
        points: Pixel coordinates (N, 2)

    Returns:
        Distortion magnitude at each point (N,), in pixels
    """
    fx, fy, cx, cy, k1, k2, p1, p2 = [float(x) for x in intrinsics]
    u, v = points[:, 0], points[:, 1]

    # Normalized (undistorted) coordinates.
    x_n = (u - cx) / fx
    y_n = (v - cy) / fy

    # Full Brown-Conrady distortion: radial polynomial + tangential terms.
    r2 = x_n**2 + y_n**2
    radial = 1.0 + k1 * r2 + k2 * r2**2
    x_d = x_n * radial + 2 * p1 * x_n * y_n + p2 * (r2 + 2 * x_n**2)
    y_d = y_n * radial + p1 * (r2 + 2 * y_n**2) + 2 * p2 * x_n * y_n

    # Back to pixel coordinates; displacement magnitude in pixels.
    u_d = fx * x_d + cx
    v_d = fy * y_d + cy

    return np.sqrt((u_d - u) ** 2 + (v_d - v) ** 2)


def compute_distortion_magnitude(
    intrinsics: jax.Array, shape: tuple[int, int]
) -> np.ndarray:
    """Compute per-pixel distortion magnitude in pixels.

    Args:
        intrinsics: Camera intrinsics [fx, fy, cx, cy, k1, k2, p1, p2]
        shape: Image shape (height, width)

    Returns:
        Distortion magnitude map (H, W) in pixels
    """
    fx, fy, cx, cy, k1, k2, p1, p2 = [float(x) for x in intrinsics]
    height, width = shape

    # Undistorted pixel grid -> normalized coordinates.
    u, v = np.meshgrid(np.arange(width), np.arange(height))
    x_n = (u - cx) / fx
    y_n = (v - cy) / fy

    # Forward Brown-Conrady distortion (radial + tangential terms).
    r2 = x_n**2 + y_n**2
    radial = 1.0 + k1 * r2 + k2 * r2**2
    x_d = x_n * radial + (2 * p1 * x_n * y_n + p2 * (r2 + 2 * x_n**2))
    y_d = y_n * radial + (p1 * (r2 + 2 * y_n**2) + 2 * p2 * x_n * y_n)

    # Distorted pixel grid; displacement magnitude per pixel.
    u_d = fx * x_d + cx
    v_d = fy * y_d + cy
    return np.hypot(u_d - u, v_d - v)


# Show original vs undistorted + distortion map
sample_idx = 1
img_orig = cv2.imread(str(image_paths[valid_image_indices[sample_idx]]))
img_orig_rgb = cv2.cvtColor(img_orig, cv2.COLOR_BGR2RGB)
img_undist = undistort_image(img_orig_rgb, est_intrinsics)
distortion_map = compute_distortion_magnitude(est_intrinsics, img_orig.shape[:2])

# Compute distortion at observation locations for this view
obs_distortion = compute_distortion_at_points(
    est_intrinsics, observations_2d[sample_idx]
)

fig_undist = make_subplots(
    rows=1,
    cols=3,
    subplot_titles=(
        "Original",
        "Undistorted",
        f"Distortion map (corners: {obs_distortion.min():.1f}-{obs_distortion.max():.1f}px)",
    ),
)

fig_undist.add_trace(go.Image(z=img_orig_rgb), row=1, col=1)
fig_undist.add_trace(go.Image(z=img_undist), row=1, col=2)
fig_undist.add_trace(
    go.Heatmap(
        z=distortion_map,
        colorscale="Hot",
        showscale=True,
        colorbar=dict(title="px", len=0.8, x=1.02),
    ),
    row=1,
    col=3,
)
# Overlay observation locations on distortion map
fig_undist.add_trace(
    go.Scatter(
        x=observations_2d[sample_idx][:, 0],
        y=observations_2d[sample_idx][:, 1],
        mode="markers",
        marker=dict(size=4, color="cyan", symbol="circle"),
        showlegend=False,
        hovertemplate="Distortion: %{text:.1f}px<extra></extra>",
        text=obs_distortion,
    ),
    row=1,
    col=3,
)

fig_undist.update_xaxes(showticklabels=False)
# Reverse the heatmap's y-axis so it aligns with image coordinates.
fig_undist.update_yaxes(showticklabels=False, autorange="reversed", row=1, col=3)
fig_undist.update_layout(
    height=280,
    margin=dict(t=40, b=20, l=20, r=40),
)
HTML(fig_undist.to_html(full_html=False, include_plotlyjs="cdn"))

The optimization calibrated the camera from checkerboard images:

  • Error histogram: Reprojection error distribution before (red) and after (blue) optimization

  • Camera poses: Top-down view of estimated camera positions relative to the chessboard

  • Single-view comparison: Observed corners (green) vs projected corners (x markers), before and after optimization

For more details, see: